#!/usr/bin/env python
#-*-coding:utf-8-*-

import os
import sys
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_dir)

import time
import datetime
import numpy as np
import tensorflow as tf
from tensorflow.contrib import learn

from model import TextCNN
from utils import data_utils, train_utils, word2vec

# Command-line flag definitions for the training script.
# The "(default: ...)" text in each help string mirrors the actual default
# passed as the second argument.

# Model Hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 100, "Dimensionality of character embedding (default: 100)")
tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
tf.flags.DEFINE_integer("num_filters", 64, "Number of filters per filter size (default: 64)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_lambda", 0.02, "L2 regularization lambda (default: 0.02)")
tf.flags.DEFINE_float("learning_rate", 0.01, "learning rate")

# Training parameters
tf.flags.DEFINE_string("train_data_path", "", "train data")
tf.flags.DEFINE_float("train_percent", 0.9, "the percent of train data")
tf.flags.DEFINE_integer("num_classes", 4, "the num of classes")

# pre_embedding is a boolean switch.  Declaring it as a string flag would turn
# "--pre_embedding=False" into the non-empty (truthy) string "False", so the
# feature could never be disabled from the command line.
tf.flags.DEFINE_boolean("pre_embedding", True, "whether to use pre embedding or not.")
tf.flags.DEFINE_string("embed_file", "", "embed file")


tf.flags.DEFINE_integer("batch_size", 512, "Batch Size (default: 512)")
tf.flags.DEFINE_integer("num_epochs", 200, "Number of training epochs (default: 200)")

tf.flags.DEFINE_integer("evaluate_every", 80, "Evaluate model on dev set after this many steps (default: 80)")
tf.flags.DEFINE_integer("checkpoint_every", 80, "Save model after this many steps (default: 80)")


tf.flags.DEFINE_integer("max_sent_length", 20, "max sentence length")
tf.flags.DEFINE_integer("decay_steps", 10,  "how many steps before decay learning rate")
tf.flags.DEFINE_float("decay_rate", 0.95, "Rate of decay for learning rate")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")


FLAGS = tf.flags.FLAGS
# NOTE(review): _parse_flags() and __flags are private members of the flags
# object and may not exist in every TensorFlow release -- confirm against the
# TensorFlow version this project pins before upgrading.
FLAGS._parse_flags()

# Echo every flag (sorted by name) so each run logs its configuration.
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")


# ---- Data preparation -------------------------------------------------------
# Read the raw samples, map them to fixed-length id sequences, shuffle them
# with a fixed seed and cut the result into a train part and a dev part.
print("Loading data...")
x_text, y = data_utils.load_fastext_train_data(FLAGS.train_data_path)

# The vocabulary is fitted on the loaded texts; every document is mapped to a
# sequence of FLAGS.max_sent_length ids.
vocab_processor = learn.preprocessing.VocabularyProcessor(FLAGS.max_sent_length)
x = np.array(list(vocab_processor.fit_transform(x_text)))
# NOTE(review): label_to_array presumably returns a numpy array of per-class
# labels (it is fancy-indexed below) -- confirm in utils.data_utils.
y = data_utils.label_to_array(y, FLAGS.num_classes)

# Fixed seed so the shuffle (and therefore the train/dev split) is repeatable.
np.random.seed(10)
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled, y_shuffled = x[shuffle_indices], y[shuffle_indices]

# The first train_percent of the shuffled samples train the model, the
# remainder is held out for evaluation.
sample_count = len(x_text)
train_count = int(sample_count * FLAGS.train_percent)
x_train, y_train = x_shuffled[:train_count], y_shuffled[:train_count]
x_dev, y_dev = x_shuffled[train_count:], y_shuffled[train_count:]

print("Vocabulary Size: {:d}".format(len(vocab_processor.vocabulary_)))
print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))


# training
# Builds the graph, session and model, prepares the output directory and
# (optionally) seeds the embedding matrix with pre-trained vectors.
with tf.Graph().as_default():

    session_conf = tf.ConfigProto(
        allow_soft_placement = FLAGS.allow_soft_placement,
        log_device_placement = FLAGS.log_device_placement
    )

    # Entering the Session as a context manager installs it as the default
    # session and also closes it (releasing its resources) when the block is
    # left, including when the script exits early.
    with tf.Session(config = session_conf) as sess:

        model = TextCNN(
            embedding_size = FLAGS.embedding_dim,
            sequence_length = FLAGS.max_sent_length,
            num_filters = FLAGS.num_filters,
            filter_sizes = list(map(int, FLAGS.filter_sizes.split(","))),
            decay_steps = FLAGS.decay_steps,
            decay_rate = FLAGS.decay_rate,
            vocab_size = len(vocab_processor.vocabulary_),
            is_training = True,
            learning_rate = FLAGS.learning_rate,
            num_classes = FLAGS.num_classes,
            # Honour the --l2_lambda flag instead of a hard-coded constant, so
            # the value printed in the parameter dump is the value really used.
            l2_lambda = FLAGS.l2_lambda
        )

        # Everything produced by this run goes under ./runs/<unix timestamp>/.
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")

        # makedirs also creates out_dir, which the vocabulary file needs below.
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # NOTE(review): tf.all_variables / tf.initialize_all_variables are
        # deprecated aliases in later TensorFlow 1.x releases -- switch to
        # tf.global_variables / tf.global_variables_initializer once the
        # minimum supported TensorFlow version is confirmed.
        saver = tf.train.Saver(tf.all_variables())

        print("save vocab ... ")
        vocab_processor.save(os.path.join(out_dir, "vocab"))

        # Initialize all variables
        sess.run(tf.initialize_all_variables())

        if FLAGS.pre_embedding:
            # Overwrite the freshly initialised embedding variable with the
            # array loaded from FLAGS.embed_file.
            # NOTE(review): presumably the rows returned by load_w2v_array are
            # ordered by vocab_processor ids -- confirm in utils.word2vec.
            print("load pre word2vec ...")
            wv = word2vec.Word2vec()
            embed = wv.load_w2v_array(FLAGS.embed_file, vocab_processor, )
            word_embedding = tf.constant(embed, dtype=tf.float32)
            t_assign_embedding = tf.assign(model.Embedding, word_embedding)
            sess.run(t_assign_embedding)


        def train_step(x_batch, y_batch):
            """
            Run one optimisation step on a mini-batch.

            Feeds the batch with the training dropout keep probability taken
            from FLAGS, executes the train op, prints a timestamped
            loss/accuracy line and returns the global step reported by the run.
            """
            feeds = {
                model.input_x: x_batch,
                model.input_y: y_batch,
                model.dropout_keep_prob: FLAGS.dropout_keep_prob
            }
            fetches = [model.train_op, model.global_step, model.loss_val, model.accuracy]
            # The first fetched value is the train op's result, which is not used.
            _, step, loss, accuracy = sess.run(fetches, feeds)

            now = datetime.datetime.now().isoformat()
            report = "{}: step {}, loss {:g}, acc {:g}".format(now, step, loss, accuracy)
            print(report)
            return step


        def dev_step(x_batch, y_batch, writer=None):
            """
            Score the model on the given batch without updating it.

            Dropout is disabled (keep probability 1.0) and the train op is not
            fetched.  Prints a timestamped loss/accuracy line and returns the
            accuracy.  The ``writer`` argument is accepted but not used.
            """
            feeds = {
                model.input_x: x_batch,
                model.input_y: y_batch,
                model.dropout_keep_prob: 1.0
            }
            fetches = [model.global_step, model.loss_val, model.accuracy]
            step, loss, accuracy = sess.run(fetches, feeds)

            now = datetime.datetime.now().isoformat()
            report = "{}: step {}, loss {:g}, acc {:g}".format(now, step, loss, accuracy)
            print(report)
            return accuracy


        # Main training loop: one train_step per mini-batch yielded by
        # train_utils.batch_iter, with periodic evaluation and checkpointing.
        batches = train_utils.batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

        for batch in batches:
            # Each batch is a sequence of (x, y) pairs; split it back into
            # parallel x and y sequences.
            x_batch, y_batch = zip(*batch)
            current_step = train_step(x_batch, y_batch)

            if current_step % FLAGS.evaluate_every == 0:
                print("\nEvaluation:")
                dev_acc = dev_step(x_dev, y_dev, )
                print("")
                # Early stop: once dev accuracy exceeds 0.95, save a final
                # checkpoint and terminate the script.
                if dev_acc > 0.95:
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print("Saved model checkpoint to {}\n".format(path))
                    print("finish train")
                    # sys.exit rather than the bare exit(): the latter is only
                    # injected by the site module and is not guaranteed to
                    # exist in every interpreter configuration.
                    sys.exit(0)

            if current_step % FLAGS.checkpoint_every == 0:
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                print("Saved model checkpoint to {}\n".format(path))
            # NOTE(review): epoch_increment is run after every batch, not once
            # per pass over the data -- confirm this matches what the model's
            # epoch counter is meant to track.
            sess.run(model.epoch_increment)